{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Using PyMC3 samplers on PyMC4 models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pymc4 as pm\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal PyMC4 model: a single Normal(0, 1) variable.\n",
    "# No name string is passed to pm.Normal; the model is declared with auto_name=True.\n",
    "@pm.model(auto_name=True)\n",
    "def t_test():\n",
    "    mu = pm.Normal(0, 1)\n",
    "\n",
    "\n",
    "model = t_test.configure()\n",
    "\n",
    "# NOTE(review): this expression is not the last line of the cell, so its value is\n",
    "# never displayed; presumably a leftover inspection of the model variables -- confirm\n",
    "# it has no side effects before removing it.\n",
    "model._forward_context.vars\n",
    "\n",
    "# Log-probability callable; `logp_array` (defined in a later cell) evaluates it.\n",
    "func = model.make_log_prob_function()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def logp_array(array):\n",
    "    \"\"\"Return the model log-probability at `array` and its gradient w.r.t. `array`.\n",
    "\n",
    "    The log-probability is computed by the notebook-level `func`; the result is\n",
    "    the pair ``(logp, grad)``.\n",
    "    \"\"\"\n",
    "    with tf.GradientTape() as grad_tape:\n",
    "        # Watch the input explicitly so the tape records the operations applied to it.\n",
    "        grad_tape.watch(array)\n",
    "        log_prob = func(array)\n",
    "    return log_prob, grad_tape.gradient(log_prob, array)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def logp_wrapper(array):\n",
    "    \"\"\"NumPy adapter for `logp_array`: tensor in, NumPy ``(logp, grad)`` out.\n",
    "\n",
    "    `logp_array` expects TensorFlow inputs and returns TensorFlow outputs,\n",
    "    whereas PyMC3's samplers want NumPy values.\n",
    "    \"\"\"\n",
    "    tensor_in = tf.convert_to_tensor(array)\n",
    "    logp_t, grad_t = logp_array(tensor_in)\n",
    "    return logp_t.numpy(), grad_t.numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "q = np.ones(3, dtype='float32')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "p = (np.ones(3) * 3).astype(q.dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([3., 3., 3.], dtype=float32)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "p\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pymc4.hmc import HamiltonianMC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "size = 1\n",
    "n_samples = 1000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed both TensorFlow and NumPy before building the sampler.\n",
    "tf.random.set_seed(123)\n",
    "np.random.seed(123)\n",
    "\n",
    "hmc = HamiltonianMC(logp_dlogp_func=logp_wrapper, size=size, adapt_step_size=True)\n",
    "\n",
    "# Chain state: the starting position, plus the containers the sampling loop fills.\n",
    "curr = np.full(size, .05, dtype='float32')\n",
    "posterior_samples, stats = [], []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%time  # NB: uncommenting cell magic %%time will prevent variable from escaping local cell scope\n",
    "\n",
    "# Each step advances the chain from `curr`; the new position and its stats are recorded.\n",
    "# NOTE(review): `posterior_samples` and `stats` are created in an earlier cell, so\n",
    "# re-running only this cell keeps appending to them.\n",
    "for i in range(n_samples):\n",
    "    curr, stat = hmc.step(curr)\n",
    "    posterior_samples.append(curr)\n",
    "    stats.append(stat)\n",
    "\n",
    "trace = np.asarray(posterior_samples)"
   ]
  },
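  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compare with `PyMC3`: run PyMC3's own `HamiltonianMC` on the same standard-normal model, using the same NumPy seed and starting value."
   ]
  },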
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Logging before flag parsing goes to stderr.\n",
      "W0414 17:15:43.725103 139715079030592 configdefaults.py:1458] install mkl with `conda install mkl-service`: No module named 'mkl'\n"
     ]
    }
   ],
   "source": [
    "import pymc3 as pm3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "with pm3.Model() as model3:\n",
    "    pm3.Normal('x', 0, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(123)\n",
    "with model3:\n",
    "    hmc3 = pm3.HamiltonianMC(adapt_step_size=True)\n",
    "\n",
    "# Starting point for the PyMC3 chain, keyed by the variable name 'x'.\n",
    "point = dict(x=np.array(.05))\n",
    "trace3 = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 260 ms, sys: 185 µs, total: 261 ms\n",
      "Wall time: 259 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "for i in range(n_samples):\n",
    "    point, _ = hmc3.step(point)\n",
    "    trace3.append(point['x'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(-0.0015888748, -0.001164150192971153)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.mean(trace), np.mean(trace3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.9673561, 0.9673307273786419)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.std(trace), np.std(trace3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.axes._subplots.AxesSubplot at 0x7f114067d080>"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3XlclOe99/HPbxZ2EGRRRBQXFMFd1KjZzKppUpulPUmbnLRpm+Z0SU+f9rR9nrR9znl6ek5P03TJaVqPWdrTpGmatEmzGc1aTapRUXHBBYiKIiogyM7AzFzPH5AeYlAGHLhm+b1fL15hZi5nvpiZrzf3fd3XLcYYlFJKRRaH7QBKKaWCT8tdKaUikJa7UkpFIC13pZSKQFruSikVgbTclVIqAmm5K6VUBNJyV0qpCKTlrpRSEchl64UzMjJMXl6erZdXSqmwtH379npjTOZA46yVe15eHiUlJbZeXimlwpKIVAUyTnfLKKVUBNJyV0qpCKTlrpRSEUjLXSmlIpCWu1JKRSAtd6WUikBa7kopFYG03JVSKgJpuSulVASydoaqUuFmyzMPDDhm8ce/HrTnGszzKXU23XJXSqkIpOWulFIRSMtdKaUikJa7UkpFIC13pZSKQFruSikVgbTclVIqAmm5K6VUBNJyV0qpCKTlrpRSESigcheRFSJyUEQqReTb5xm3UER8InJL8CIqpZQarAHLXUScwEPASqAQuE1ECs8x7j+A9cEOqZRSanAC2XJfBFQaYw4ZY7qAp4BV/Yz7CvAnoDaI+ZRSSg1BIOWeAxzrc7u6976/EZEc4EZgdfCiKaWUGqpAlvyVfu4zZ93+GfAtY4xPpL/hvU8kcjdwN8CECRMCzahUSPK1N9F1+gj4OhFfN13uFI7sfofxBcW4YuJsx1NRLpByrwZy+9weD9ScNaYYeKq32DOA60TEa4z5c99Bxpg1wBqA4uLis/+BUCosdDefJPnEZmZ178UhPW9jj3ET29kNz/6ZZhLYn34NyUs+zYz5lyEOnZSmRl4g5b4NyBeRScBx4Fbgk30HGGMmvf+9iPwGeOnsYlcqEnSd2MdFp5+jmQRedF+LN2sWyYnxxDuhpb2djKxspHwds+vXEv/Snyl7pZCmhfey8KpP4HY5bcdXUWTAcjfGeEXky/TMgnECjxljykTknt7HdT+7inzG4K/axCWtr1MiMzg16UbGJnzw45OenMCi6z8HfI62pgZ2rP8vxu9/hKJ37+HAlh9xrOgfWPyRO0mJj7XzM6ioIsbY2TtSXFxsSkpKrLy2UoN18IUfM33H93lVluLKv4JEd/+7Ws6+LJ6/28PB1x4hdfsvyPbVUGnGs3/WNxnlO02sY+DPnl5mT51NRLYbY4oHGqc7A5UaQO3+d5i849/YxBzip191zmLvj8Mdy4zrvkT2fXs5dsUvSHIbbth7L459z7GvzjOMqVW003JX6jy6muvhmU9zyozmTN51xAx1t7nDSe6ldzD226UcXXgfc6WCW049yPaKatp9+jFUwafvKqXOo/K3X2SUr5GqK39JeqL7wp/QFcOEj3yT/fl3U+/K5otdj9G0/y1Oe849hVipodByV+ocTu55k8L69WzI/CTLLr0mqM/tjxlF7fRPsjvlMj4mG5CKV6n36MdRBY++m5Tqj9+H58V/4oRJZ+5t/zw8ryEOOnIvozT1albKZkzFazRowasg0XeSUv2oWPdLJnZVsqfw62Slpw/ra3lylrA79Uqul79Sf2gnfj29TwWBlrtSZ/F1tpK57X52OQq57KYvjMhrdoxbSlncPG7zvcjO6pYReU0V2bTclTrLgZd+TqppouWS7xDrDuQk7iAQoX3i1TRKKh858yRVLXqAVV2YEXrnKhUe/J52ssvWsMM5h6WXXTfoP7/lmQeG/tquOI7mrmLu0d9QenQD/sJLh/xcSumWu1J9HFz7C0abM7Rd9HUc
jpHfevam5FKauIzrzQYO1nWM+OuryKHlrlQv091J5u7VlDqKWHLFDdZy+McvxiMx5Na+RbfPby2HCm9a7kr1OvjaI2SY05xZ+I+4nPY+Gn53ImXJl3CllPD6G+us5VDhTctdKQBjSNyxhoMyiWVX32w7DYxbwBmSSNv8Qzq7fbbTqDCk5a4UcLTkZXK9VVRP/0xIrLvud8VyIPUyLjK72LRBrzmvBk/LXSmgbcOD1JlUFlx3l+0of+McM5N24pBtD9uOosKQlruKeg1H9jCjdQu7sm8hNSXZdpy/8btiqcpdxdLOjZRVvGc7jgozOs9dRb0DT36LBcaNJy7zguapD4fca+8l9pE/cOS11RTl3287jgojuuWuolpX2xnmekp4x7mQjGAs6RtkSeNn8l7SAuaeepbGFp33rgKnW+4qqh189WFmiYczGfNJsh2mH1ueeQBf4jSWtm7niTX3kT8hp99xejk+dTbdclfRyxhSyx5nv8ljXMZo22nOyZk+hXpGMbFlh+0oKoxouauodaz0DXK9VexLuggLKw0EThyUx81lkdlDXZvOeVeB0XJXUatx42qaTQKp46bYjjIgk1lArHhprz1kO4oKE1ruKiq1NZxgRsOblI5eSdKQr3o9clzJWVQzhsltpRi9mIcKgJa7ikrl636FW3xkLP8H21ECI8KhxDnM5wAnmzttp1FhQMtdRR3j85Jd+Xt2u2YxY1ax7TgBc2VNxyEGf12F7SgqDGi5q6hTufl5xvpraZn194iE8pHUD3ImpFEheczo3KG7ZtSAtNxV1PFsfph6RjHvmtttRxm06qSZzJAqTjXpCU3q/LTcVVRpOF5JYeu77Bv7MRLiE2zHGbSYjEkA+E9XWk6iQp2Wu4oqh9Y/hAEmXvMl21GGxJWQxhFyyOvcZzuKCnFa7ipq+Lo9TD76J0rjFzNx8nTbcYasKr6Q2aachrYu21FUCNNyV1Fj3xtPMJomTPFnbUe5IJI+FacYOuqrbEdRIUzLXUUN985fU80Y5l5+k+0oFyQ2JYNTjCa7bb/tKCqEabmrqHD84A4KPHs4kvcJXK4wXwxVhPLYmczzl9Hq0bVmVP+03FVUOPHGL/AYN9NXhskZqQPoSssnTrppqjtuO4oKUVruKuJ1tjVRULuWXSmXkzmm//XQw01S2lhaTDyprXq2quqflruKeHvXPUoSHSQsu9t2lOBxODngKqDIuw+fX09XVR+m5a4imzGM3vdb3nPkUbToKttpgqoxOZ8x0kj9mSbbUVQICvMjS0qd25ZnHuDM6Vqu9R3mpaRbqP/TT21HCqqE9Fw4A45GXeNdfZhuuauIllS/k1YTR8q4fNtRgs4dl0gFE5jQedB2FBWCAip3EVkhIgdFpFJEvt3P46tEZLeIlIpIiYhcHPyoSg1OZ2cHC7u3sy1mEbFut+04w+JYfAFFppK62pO2o6gQM2C5i4gTeAhYCRQCt4lI4VnD3gDmGGPmAncBjwQ7qFKD1XHiIDHipWvMXNtRho0/dRJOMVRuft52FBViAtlyXwRUGmMOGWO6gKeAVX0HGGNajfnbCtOJgB6+V1b5fT5mtm1mD/mkjkq1HWfYJKVm0miScFS+ZjuKCjGBlHsOcKzP7ere+z5ARG4UkQPAy/RsvStlzZ6NzzFe6jiWutB2lGElDgcHXDPIb3mXbq/XdhwVQgIp9/4uVfOhLXNjzHPGmALgY8D3+30ikbt798mX1NXVDS6pUoPg3/oIp00KqWPybEcZdk3JUxlNCwd3bLQdRYWQQMq9Gsjtc3s8UHOuwcaYjcAUEcno57E1xphiY0xxZmbmoMMqFYiaqnJmt7/LzrjFOMN9HZkAJGbm4jfCmV0v246iQkgg5b4NyBeRSSISA9wKvNB3gIhMld6LUYrIfCAGOB3ssEoF4vD6XyKAI3um7SgjIiYmjoqYAjJPbrAdRYWQAcvdGOMFvgysB/YDTxtjykTkHhG5p3fYzcBeESmlZ2bN3/U5wKrU
iOlob2NGzZ/Ym7iYxMRk23FGTNP45Uz3VVBbc9R2FBUiAprnboxZa4yZZoyZYoz5Qe99q40xq3u//w9jTJExZq4xZokx5p3hDK3Uuexa9yijaca1LDwvozdUmfOvB+DwlhcGGKmihZ6hqiKG8fvJ3PsYRxwTmbHkettxRlRe0RLqSMOpUyJVLy13FTH2blrLFP9h6mfehTii660tDgdH0pYyrXUrXV16bVWl5a4iSPemX3KGZGat/JztKFa4C64lRdopL3nDdhQVAiJ/npiKOFueeeBD951paeHqtk28GX81yS/9l4VU9k296Aa6N32N5r2vwNKVtuMoy3TLXUUE54lSfDhwZc+yHcWapFGjqYgtYsxJPZlJabmrCNDp6eKirs1sc80nPiHRdhyrWnIvZ4r/MCerdY33aKflrsJe64mDJEknLWMW2I5i3dgFHwWg6l1dJTLaabmrsOb3+5nd+lf2MYXUtCzbcaybULCAk2TgPqQHVaOdlrsKa/WnqsmVWqrSLrIdJSSIw8HR0UuZ1laCx9NhO46ySMtdhbXcxnc5ZUaTNjbPdpSQETNjBUnSQcW2121HURZpuauw1XCmkbnmAKWJSxGH03ackJF/0UfoMk5a9q61HUVZpOWuwlbCqe10mBjix5191cfolpicysG42Yyt1SWeopmWuwpLbR2dXNS9ja0xi4iNjbMdJ+S0TVjOJP9Rao4ctB1FWaLlrsKSp2YvsdJN99j5tqOEpHELey5zfHSrrhIZrbTcVdjx+vwUd/yVHVJEckrkXvz6QuROnU2NjCHuiE6JjFZa7irsNJ44RKY0cSp9ke0oIUscDo6lL2Na2w46O9ptx1EWaLmrsGL8fqY2baKKsaRljrcdJ6TFFa4gQTxUbFtvO4qyQMtdhZV9W1+jgCPsS1qGOMR2nJA2bfF1eIybtr2v2I6iLNByV2Gl8+2HaDIJJOdMsx0l5MUnJnMgfi45dW/bjqIs0HJXYaPmyEHmtm6kJG4Jbpfbdpyw0DHxCnJNDccPldmOokaYlrsKG1Xrfo5BkOw5tqOEjZyFPatEVuuUyKij5a7CQltLE0Un/8zu5EtITEyyHSds5E6dyTEZR7xOiYw6epk9FTL6u3ze++qryvgIbdQmFpA2gpkiwfGMi5lb+xyd7a3EJeg/jNFCt9xVyPP7DTNb/spB8khL1zXbByu+aAVx0k35Fp01E010y12FvIa64yyRk6xNvY100emP/Tnfbz3dXh/tJpbTm34Lyz8+gqmUTbrlrkLe2NNbqDOppGZPth0lLLldTsoc08nv2g/G2I6jRoiWuwppTc3NLDBl7ExYglPXbB+y2sR8xksd1ZW7bUdRI0TLXYU056lddBsn7uwi21HCWkx6HgDHt+qFs6OFlrsKWV3d3Sz0bGGbax4J8Qm244S1lKREKhlPUtVrtqOoEaLlrkJW84lKUqSdpowFtqNEhMOxM5ju2UvzmXrbUdQI0NkyKiQZv6GgZTOV5JI2Wqc/BkN32hRcJ1/jpWefgKKbzjnuk4snjGAqNVy03FVIajh9iouoZm3yx0nX1R+DIjUtg4YTyUw8/jKuZN+5By7++siFUsNGd8uokJR2ejtNJoGUcfm2o0QMp8NBmauIIm8Zfr/fdhw1zLTcVchpa2+n2FtKSexi3C795TKYmpLzSZU2zjTU2o6ihpmWuwo53Sf24sSPGaurPwZbckYOXcaJ+8wh21HUMNNyVyHF5/Mxv2MzOxxFJCen2I4TceJiYyhzTGOKZ5/tKGqYabmrkNJw8giZ0kTt6IW2o0SsmoTpTKKG1tYW21HUMNJyVyElt2k7NSaDtCy9+PVwcadPAqCr/rDlJGo4abmrkNHc2sIcc4A9CYtw6PTHYTMqOZnDZhzZ7QdtR1HDKKByF5EVInJQRCpF5Nv9PP4pEdnd+7VJRPRImBo0c3IfPiPEjJlhO0rEq4ydQaG/nK6uLttR1DAZsNxFxAk8BKwECoHbRKTwrGGHgcuM
MbOB7wNrgh1URTZvdzdzOrdS6iwiITHRdpyI15U6lRjx0VJfbTuKGiaBbLkvAiqNMYeMMV3AU8CqvgOMMZuMMY29N98FdIepGpS9G59ljDRyKnWe7ShRITU9kyaTSEpLpe0oapgEUu45wLE+t6t77zuXzwJ6PS81KP7tj9Ngkkkdo+uajASXw8FeVyGF3Xq2aqQKpNz7O7LV7+VcRGQ5PeX+rXM8freIlIhISV1dXeApVURrqDvBzLZNlMYW6wU5RlBD0jRGSwtNjbpKZCQKpNyrgdw+t8cDNWcPEpHZwCPAKmPM6f6eyBizxhhTbIwpzszMHEpeFYEOvvHfxIgPb+bZh3LUcErJHI/XOHA26tmqkSiQct8G5IvIJBGJAW4FXug7QEQmAM8CdxhjyoMfU0WytMrnOOzMY1Rquu0oUSUuNpYyRz6T9WzViDRguRtjvMCXgfXAfuBpY0yZiNwjIvf0DvsekA78UkRKRaRk2BKriFJVvpsC7wFqJ33MdpSodDy+gClU09bWajuKCrKAltwzxqwF1p513+o+338O+Fxwo6locHzjb8g1wpTln+a9Tc/ajhN1nKPzoB089VUkJup1aiOJnqGqrPH7/Ew8/hL74ueRkTPJdpyoNCplFEfNGMa0HbAdRQWZLpatht2WZx7o9/76+lo+Yk6xO3Y5becYo4aXCJTHFnKp5222ertxu9y2I6kg0S13ZU18w346jZvkrDzbUaKaZ9QUYsRLc91x21FUEGm5Kyt8fj+zukrZ5ZyJOybGdpyolpY+hmaTQHKzTnSLJFruyorT9afIlCYaRulBPNucTie7XTOZ2b1Xz1aNIFruyorkxv20mVhSsnS5gVBwJnk6adJK0+mTtqOoINFyVyPO6/Mzp7uU3a5ZuPQC2CEhJSuXTuMm9owuJBYptNzViGuoO9GzlZiqyw2Eili3i92OQgo8ezC6ayYiaLmrEZdyZh8tJp6UDF0ZOpTUJheQLac5vGeT7SgqCLTc1Yjy+vzM8e5hj2sWTt0lE1ISMifiM0Ldtj/ZjqKCQMtdjaiG+pOkSitNqQW2o6izJMbFsdcxjbE1r9uOooJAy12NqMQzB2k3MYzSXTIh6VhCIRP9Rznx3h7bUdQF0nJXI8bnN8zq3s1eZ5HukglRsRmTATi2+RnLSdSF0nJXI+Z0Qx2Z0sTplBm2o6hzSElKpMI5hdSq9bajqAuk5a5GTFxDOV3GRXKmnrgUympzrmZa9wEaTlbZjqIugJa7GhF+v6Goazd7nDN0LZkQN2bRzQAceudpy0nUhdByVyOiobGBcVJPXZLukgl1UwqLOSrZxFe+YjuKugBa7mpEuBor8BoHiZkTbUdRAxCHg2NZVzKto5TWpn6vda/CgJa7GnbGwLTOPexzTCMmLt52HBWAUfNvwi0+Kjb8wXYUNURa7mrYNTQ1MUlOcCJRd8mEixnFy6khC/d+va5tuNJyV8NOTvesNBiXNdlyEhUop9PBobHXUtC+nZaGE7bjqCHQclfDbnLnHvYxhbj4RNtR1CBkLL4Nl/ipeOtJ21HUEGi5q2F1pKKM6RzlWKJecSncTJ+zhMMynoTy52xHUUOg5a6G1fHNPXOlYzKnWE6iBkscDqpzrmNa514aTx6xHUcNkpa7GlZpR9dTwQQSEpNtR1FDkL3sdhxiOPSXx21HUYOk5a6GzfGjhyj07udI/EzbUdQQTSmYzQFHPqMrddZMuNFyV8PmyF97dsk4MnSXTLgSEWqn3sIk7yGq971rO44aBC13NWxSDq/lqGM8SSlptqOoCzDjqs/Qadyc2vCI7ShqELTc1bCoO3WcGZ49nBx3te0o6gJlZo2hNOlipp56BV9Xh+04KkBa7mpYVLz9R1ziJ2vxJ2xHUUHgmHcHo2jlwF+esh1FBUjLXQ2L+MqXOSFZTCy6yHYUFQRzL1vFCTKQ0idsR1EB0nJXQddQX0tRRwnVY69CHPoWiwQxbhcHsz9KQdt2GqoP2o6jAqCf
PBV0Bzc8RYz4SF98q+0oKogmXvMlfDg4+srPbEdRAdByV0EXX/48NTKGSbMvsR1FBdGkSVMpSbyUqcf/THd7k+04agBa7iqo6mprKOrcSfW4a3WXTASKWfZFkmhn/7o1tqOoAeinTwVVxV9+j1t8jFnySdtR1DCYt+Rq9jmmkVH2GPj9tuOo89ByV0GVVPkC1Y5xOksmQjkcQm3hZxjnq+HQZl0tMpRpuaugqT1xjCLPLmpyVoKI7ThqmCxY+WlqyICN9/dcQ1GFJC13FTQVb/43TjFkL/uU7ShqGCUnJrBv6ueZ7NlP1dYXbMdR5+AKZJCIrAB+DjiBR4wxPzzr8QLg18B84D5jzI+DHVSFNmMMWe89S6VrKlMLFtiOoy5Eya8HHLIwL4PjlZn43vx3WPRR/U0tBA245S4iTuAhYCVQCNwmIoVnDWsA7gW01KNU5e4t5PvfozH/FttR1AgYFedk//tb71uetx1H9SOQLfdFQKUx5hCAiDwFrAL2vT/AGFML1IrIR4YlpQp5de/8monGyfSrPmM7irpAWw43BDRu4ce+TPWPH8a8+YOerXed+hpSAvm/kQMc63O7uvc+pQDo8niYXvcKZclLSUkfazuOGiGjkhLZX3AveV3lVL6m895DTSDl3t/OtCEdIheRu0WkRERK6urqhvIUKgTt2fAs6TThnK8HUqPNJTd9kT2O6aS/++942xptx1F9BFLu1UBun9vjgZqhvJgxZo0xptgYU5yZmTmUp1AhSEofp4EUCi++yXYUNcLiYly0LP83RvmbOPj0d23HUX0EUu7bgHwRmSQiMcCtgM5/UgBUHz7AnLZNVOTchCsm1nYcZcGSi69kQ9IKplU9SePhUttxVK8By90Y4wW+DKwH9gNPG2PKROQeEbkHQETGikg18L+A74hItYikDGdwFRqOrnsQgzD5unttR1GWiAgTP/FDWkwCLU99DuPtsh1JEeBJTMaYtcaYacaYKcaYH/Tet9oYs7r3+5PGmPHGmBRjTGrv983DGVzZ19bSxMxTf2Z3yqVk5uhFsKPZ5Il5bJv5PSZ4Kjjwx3+xHUehZ6iqC7Bn3SOk0EbCJV+yHUWFgKtu/jwbYi9n6oFfcbpiq+04UU/LXQ2J8fsZs/83VDqnML34KttxVAhwOoSJt/+CRpOC5w+fwdehv7zbpOWuhmTna08yyX+UM7M/q+u2q7/Jy81l75IHGNN9nIpHP6sLi1kU0NoySvXl9/lI2/Ijjsk4ujrb2fLMA7YjqRCy/NqbeOXQZq6rfZjyl3/GtOu/ZjtSVNJyV4O2c91jLPBXUVL8Y5z6q3fUOt8/6kmjs9lcO5vibd/nRO5ssudcOYLJFOhuGTVI3u4uskp+wmHHROavvMt2HBWiYp3QOOl6qski4bk7aao+YDtS1NFyV4NS+sJD5JoaGhb9Ew6n03YcFcLSE1zszPkkfgNtv74JT8tp25Giipa7ClhjbTX5e+5nn7uIeVfrOjJqYDlpiey95CEyvCep+tXNGK/HdqSooeWuAnbkia8QbzzE3/wQDqe+dVRgLrlqFRsKvse09p2UrfmczqAZIfoJVQEpe+tp5jW/ydYJn2VSwTzbcVSYuerWr/Ja5p3MrH2BPX/4Z9txooLOllEDaqo9RtaGb3HYMYHiT+mp5Wpw3p9Vk5A+gb/UL+LyAz/jzZ/XkDiu4APjFn/86zbiRSzdclfn5ev2cPLRW0kyrXhuWE1cXLztSCpMuZwC+ddQQiGXNvyR1tojtiNFNC13dV67H/kS0z172THvXymYt8x2HBXm4t0OmqZ+jHImsqz297Q1nLAdKWJpuatz2v30vzLv1DNsyLiVZR/7gu04KkIkxbo4NvnvOEk6xTW/o71Zp0gOBy131a89z97P7H338278JSz+/IO246gIk5oQw74Jn6KDWGYe/R2d7Xqmc7BpuasPKXv+AWbt/le2xi5l9r3PEBerV1hSwZeeksjWcXcQQxeTD/+e5vohXb1T
nYPOlok2Jb8+92PGz+7NrzK7YR3vOBZSuexnVO4+/4XM9RId6kJkjU7lLe+dXFv7KCdW34D7q68Tn5xmO1ZE0C13BYDf66XsraeY3bCODbGX897y1cTE6swYNfyyszJZP/p2JnQf4vBDN9LV2W47UkTQLfcos+Vww4fu83d3MqriOYpMBc/GrCJ78mwKTj5vIZ2KVtnjxrM19/+xdPd32PPgjcz4x+dxxcTZjhXWdMs9ynW1N5NT/lum+A/zTPIdjJs6B4dDbMdSUWjpTV/hnYLvMKv9XcoevAVft15o+0JouUexjjO1FB16jFGmmRczPs+ECZMQ7XVl0cW3/hPvTP0Gc1rfZvd/3orf67UdKWxpuUepttojLKp+jHZi+EvO3eSOzbQdSSkALr79u7yT9xXmNb/Bzoc+hd/nsx0pLGm5R6HW4we4tPYJqhjH3ry7GJuWYjuSUh+w7M7v8874z7OgcR07HrpTt+CHQMs9yrRV7eTKM09TKgWcyP8kaUl60EqFHhFh2V0/4q/jPk1xw4vsfPDv8HbpWvCDIcbS2srFxcWmpKTEymtHI+P3s+PXX2PBsd/wjmMB/qkriHXrlZRUaDMGGg/vYGXHS5Q4ZtGVfwNLb/uW7VhWich2Y0zxQON0yz0K+L1edv7y71lw7De84boEmbZSi12FBREYPXk+LyfdzHzfXhLL/0Rbc6PtWGFByz3CeTrb2PWzG5lf/yIbx36ahPzLcelVlFSYyZhYxNrUWynyV1Dz4DU0nz5lO1LI0095BGtpaqDyJyuY17qRd6Z+g0u+8DOdw67CVub4fNal38GE7sM0PbScmvfKbEcKaVruEaruxFFOPXgV0zxlbJ33H1x8+3cRncSuwlxm9kQqVzxOkr+ZhMev4cC7a21HClla7hHoaGUZnjVXM85bzf7L17Bo1T22IykVNEVLVtJ6x3qaJJUpr9zO1ie/j/H7bccKOVruEWbf9o3EP3EdSaaV4x/9A7OX32I7klJBlzuliNSvbGBv4mIWlf+YnT+9kZamD6+bFM104bAI8tfnVjO/9Ls0O1Jove3P5E+bZzuSUkH3/gW3AfwTr+SVqnSubn6Z2p8U886Ymxidma0X20a33CNCW1sbG//z8yzb9S2q4qaHNHuWAAAIhklEQVQT/8W3Ga/FrqKAwyGMnjSPV7O/gBcXK2sfpr1iI63NuhWv5R7myko3U/PAUi49/TQ7x3yc/G+8SUrmONuxlBpR6emZVE2/i7fcl3KZZwMdP5nPjpcfjup98VruYaqhsZG3fnkvU5+7gXR/I+VXPMy8f3gEpzvGdjSlrIhxu0mYdjnrsu+hyTma+du+QcW/LWb/ppdsR7NClx8IM63t7Wx74VcUHvgFY2hgd9o1TL7955S9+Xvb0ZQKGX6/n9PHyyluep2x0sAumU5NxsWkZYxD+pzrEY775gNdfkAPqIaJ48erKX91DTOqnmA5pzkck8/R6x5l9twrbEdTKuQ4HA4ycwt4L3sqO6rLKG7dwJy6Rymvm0Bl8iKSx+YTE+O2HXNYablb9OSWo+d9vKO5gWltW4k5+CJz2zeRI17K42bTftlPmHzRKvTKGkqdn8vlIj1vDuXemWyrKaegZRPXtfyRtuZYStzFVO5cyNS5l0TkZ0nLPYR4fV48NftIqX6LaU2bmOU/gEv8nCGFfTkfZ9zyzzMtf4HtmEqFHbfLScaEGdT5CyhvrCOxfhcLu0tIeP4Gjrw4gZPjriZz0c1MnrkEcUTGociA9rmLyArg54ATeMQY88OzHpfex68D2oFPG2N2nO85o32fuzGG1x//Ee1Ntbhaa8jyHCXfHCZV2gB4j1yOxRcwZdpMxufkDviG6+/C10qpc+v0dBHndpL43ovM8OzBKYYayeJo1hUkFl7LpHnLSUpJsx3zQ4K2z11EnMBDwNVANbBNRF4wxuzrM2wlkN/7tRj4Ve9/o1pHRycNdTU0n66hveEEnbXvQX05SS2HGNt1hKvpWbrU
b4QqGUdZzBw6EnOIGz0ed3wy8UBu7mi7P4RSESouNqb3gOo3OV17nMq3/0hs5cvMP/lHYk49he9NodI1mfr0BcROWkrGlHmMzZuBOybWdvSABLJbZhFQaYw5BCAiTwGrgL7lvgr4ren5NeBdEUkVkWxjzImgJ+6H8fvx+bz4/X78fh/G78fv8+Lt9uDpaKers41uTztdHW14uzrxetrxdbXj7+ro/WqH7g5Mdyd4O3B4OxBvJw5fJw6fB5evE4e/C39XB058CAYH/v/5Mn6cfW47xRBnPKRKKzlATp+sbcRxwj2BmrTF7OlyQmIWCalZON2xuIDks3423SJXavilZ+WQfvNXga/S1nKGgzvforX8bVJqtzH31HPE1T4NW6DbODnmGMPp+Il0Jk/EJGbhTEzHnZxB/Kgs4lJG43LH4Y6JxRkThzs2ntjYWNwx8ThdI7sXPJBXywGO9bldzYe3yvsbkwMEvdzLXn+cqW9/7QNFKgTv4EGXcdEpsXiIoUti6JZYuhxx+MSNX4RuYjAIfnH0VrxgcOKXnso34iA2OR3jisOfkIkrOYuYUWNISMtmdM4U0sbmMbX34E3f06iVUqEhMTmVWZfeCJfeCIDH00F52Raaju7FW1tObNNh0jqqmN62nXjpCvh5fUbw48Ag7J54J8V3/WS4fgQgsE7s7zDy2TvqAxmDiNwN3N17s1VEDgbw+hciA6gf5tcYKs02NJptaDRbv74x0IBhyvZT+OxPh/qHJwYyKJByrwZy+9weD9QMYQzGmDXAmkCCBYOIlARy4MEGzTY0mm1oNNvQhHK2gQQy52cbkC8ik0QkBrgVeOGsMS8Afy89LgKaRmp/u1JKqQ8bcMvdGOMVkS8D6+mZCvmYMaZMRO7pfXw1sJaeaZCV9EyF/MzwRVZKKTWQgI5DGmPW0lPgfe9b3ed7A3wpuNGCYsR2AQ2BZhsazTY0mm1oQjnbeVlbOEwppdTwiYzzbJVSSn1AVJS7iHxDRIyIZNjO8j4R+b6I7BaRUhF5VURC6gobInK/iBzozficiKTazvQ+Efm4iJSJiF9ErM9kEJEVInJQRCpF5Nu28/QlIo+JSK2I7LWd5Wwikisib4nI/t7/n1+1nel9IhInIltFZFdvtn+xnWmwIr7cRSSXnqUTzr8E48i73xgz2xgzF3gJ+J7tQGd5DZhpjJkNlAP/23KevvYCNwEbbQfpszzHSqAQuE1ECu2m+oDfACtshzgHL/B1Y8wM4CLgSyH0d+cBrjDGzAHmAit6ZwKGjYgvd+CnwDfp56Qqm4wxzX1uJhJ6+V41xnh7b75Lz7kLIcEYs98YM9wnwAXqb8tzGGO6gPeX5wgJxpiNQEiuYWGMOfH+AoPGmBZgPx9crcMa06O196a79yukPqMDiehyF5GPAseNMbtsZ+mPiPxARI4BnyL0ttz7ugt4xXaIEHWupTfUIIhIHjAP2GI3yf8QEaeIlAK1wGvGmJDJFoiwX89dRF4Hxvbz0H3A/wGuGdlE/+N82Ywxzxtj7gPuE5H/DXwZ+L+hlK93zH30/Pr8u1DLFiICWnpDnZuIJAF/Av7xrN9orTLG+IC5vcebnhORmcaYkDt2cS5hX+7GmKv6u19EZgGTgF09y80zHtghIouMMSdtZuvHk8DLjHC5D5RPRO4ErgeuNCM8Z3YQf3e2BbT0huqfiLjpKfbfGWOetZ2nP8aYMyLyF3qOXYRNuUfsbhljzB5jTJYxJs8Yk0fPh3D+SBX7QEQkv8/NjwIHbGXpT+8FWr4FfNQY0247TwgLZHkO1Y/ei/w8Cuw3xgzvEomDJCKZ788QE5F44CpC7DM6kIgt9zDwQxHZKyK76dl1FDLTwHr9gp7l5V/rna65eqA/MFJE5EYRqQaWAC+LyHpbWXoPOr+/PMd+4GljTJmtPGcTkd8Dm4HpIlItIp+1namPZcAdwBW977FSEbnOdqhe2cBbvZ/PbfTsc3/JcqZB0TNUlVIqAumWu1JKRSAt
d6WUikBa7kopFYG03JVSKgJpuSulVATScldKqQik5a6UUhFIy10ppSLQ/wcZbgQk0/SoUAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import seaborn as sns\n",
    "sns.distplot(trace)\n",
    "sns.distplot(trace3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There still seems to be a problem in the PyMC4 implementation: the `step_size` keeps getting smaller and smaller, which makes the sampler take a very long time. I haven't figured out the cause yet."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.2951355636865856"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "hmc.step_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.2951356724293117"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "hmc3.step_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.97152436])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "hmc.potential._stds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.97152434])"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "hmc3.potential._stds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
